/**
 * Copyright 2024-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "model_manager_ndk_impl.h"

#include <unistd.h>
#include <algorithm>

#include "infra/base/assertion.h"
#include "infra/base/securestl.h"

#include "framework/infra/log/log.h"

#include "model/built_model/built_model_ndk_impl.h"
#include "tensor/core/nd_tensor_buffer_ndk_impl.h"

#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_options.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_nncore.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_tensor.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_executor.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_create_itf.h"
#include "model_manager/general_model_manager/ndk/hiai_ndk/hiai_ndk_nativehandle.h"

#ifdef AI_SUPPORT_AIPP_API
#include "model/aipp/aipp_input_converter.h"
#include "tensor/core/aipp_para_ndk_impl.h"
#endif

namespace hiai {
/**
 * Destructor: if a model is still loaded, unload it so the executor and the
 * shared-memory callback context are released.
 */
ModelManagerNDKImpl::~ModelManagerNDKImpl()
{
    if (executor_ == nullptr) {
        return;
    }
    UnLoad();
}

/**
 * Registers the async completion and service-death callbacks on the executor
 * when a listener is supplied. A null listener means synchronous mode and is
 * a no-op.
 */
Status ModelManagerNDKImpl::PrepareModelManagerListener(const std::shared_ptr<IModelManagerListener> &listener)
{
    if (listener == nullptr) {
        return SUCCESS;
    }

    std::lock_guard<std::mutex> lock(listenerMutex_);
    listener_ = listener;

    // Typed locals select the static (trampoline) overloads of OnRunDone /
    // OnServiceDied rather than the member-function ones.
    NN_OnRunDone runDoneCb = ModelManagerNDKImpl::OnRunDone;
    Status ret = HIAI_NDK_NNExecutor_SetOnRunDone(executor_.get(), runDoneCb);
    HIAI_EXPECT_TRUE(ret == SUCCESS);

    NN_OnServiceDied serviceDiedCb = ModelManagerNDKImpl::OnServiceDied;
    ret = HIAI_NDK_NNExecutor_SetOnServiceDied(executor_.get(), serviceDiedCb);
    HIAI_EXPECT_TRUE(ret == SUCCESS);

    return SUCCESS;
}

namespace {
/**
 * Applies the model init options (band mode, performance mode, async flag)
 * onto the compilation object.
 *
 * When the requested band mode was not accepted (read-back differs from what
 * was set), the band mode is folded into the performance-mode encoding as
 * bandMode * 100 + perfMode % 100 so it can still be conveyed downstream.
 */
Status SetInitOption(OH_NNCompilation* compilation, const ModelInitOptions &options, bool isAsyncMode)
{
    const HiAI_BandMode bandMode = static_cast<HiAI_BandMode>(options.bandMode);
    Status status = HIAI_NDK_HiAIOptions_SetBandMode(compilation, bandMode);
    HIAI_EXPECT_TRUE(status == SUCCESS);

    uint32_t perfMode = static_cast<uint32_t>(options.perfMode);
    if (HIAI_NDK_HiAIOptions_GetBandMode(compilation) != bandMode) {
        const uint32_t band = static_cast<uint32_t>(options.bandMode);
        if (band != 0) {
            perfMode = band * 100 + perfMode % 100;
        }
    }

    status = HIAI_NDK_NNCompilation_SetPerformanceMode(compilation, static_cast<OH_NN_PerformanceMode>(perfMode));
    HIAI_EXPECT_TRUE(status == SUCCESS);

    status = HIAI_NDK_HiAIOptions_SetAsyncModeEnable(compilation, isAsyncMode);
    HIAI_EXPECT_TRUE(status == SUCCESS);
    return SUCCESS;
}

/**
 * Returns true when the executor must be rebuilt: async mode is requested or
 * the options differ from the defaults used when the model was built.
 * NOTE(review): bandMode is compared against PerfMode::UNSET here — looks as
 * though a band-mode enumerator was intended; confirm the enum values line up.
 */
bool IsOptionsChanged(bool isAsyncMode, const ModelInitOptions &options)
{
    const bool changed =
        isAsyncMode || options.perfMode != PerfMode::MIDDLE || options.bandMode != PerfMode::UNSET;
    if (changed) {
        FMK_LOGI("isAsync or options has changed, need reload");
    } else {
        FMK_LOGI("don't need reload");
    }
    return changed;
}
}  // namespace

/**
 * Creates (or reuses) the NN executor for the given built model.
 *
 * Reuses the executor cached on the built model when one exists and the init
 * options do not require a reload; otherwise applies the options to the
 * compilation, constructs a fresh executor, and caches it back on the built
 * model. The compilation handle is kept for later option changes
 * (e.g. SetPriority).
 */
Status ModelManagerNDKImpl::PrepareModelManager(
    const ModelInitOptions &options, const std::shared_ptr<IBuiltModel> &builtModel)
{
    auto builtModelNDKImpl = std::dynamic_pointer_cast<BuiltModelNDKImpl>(builtModel);
    HIAI_EXPECT_NOT_NULL(builtModelNDKImpl);

    auto nnCompilation = builtModelNDKImpl->GetNNCompilation();
    HIAI_EXPECT_NOT_NULL(nnCompilation);

    // If the load options have changed (or no executor exists yet), reload.
    auto executor = builtModelNDKImpl->GetExecutor();
    if (executor == nullptr || IsOptionsChanged(isAsyncMode_, options)) {
        HIAI_EXPECT_EXEC(SetInitOption(nnCompilation.get(), options, isAsyncMode_));
        // Custom deleter destroys the executor when the last owner releases it.
        executor_ = std::shared_ptr<OH_NNExecutor>(HIAI_NDK_NNExecutor_Construct(nnCompilation.get()),
            [](OH_NNExecutor* p) { HIAI_NDK_NNExecutor_Destroy(&p); });
        HIAI_EXPECT_NOT_NULL(executor_);
        builtModelNDKImpl->ReSetExecutor(executor_);
    } else {
        executor_ = executor;
    }
    nnCompilation_ = nnCompilation;
    return SUCCESS;
}

/**
 * Initializes the manager from a built model. When a listener is supplied the
 * manager runs in asynchronous mode and wires the listener callbacks onto the
 * executor. Fails if the manager is already initialized.
 */
Status ModelManagerNDKImpl::Init(const ModelInitOptions &options, const std::shared_ptr<IBuiltModel> &builtModel,
    const std::shared_ptr<IModelManagerListener> &listener)
{
    HIAI_EXPECT_NOT_NULL(builtModel);

    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    // Re-initializing an already-loaded manager is not allowed.
    HIAI_EXPECT_TRUE(executor_ == nullptr);

    customModelData_ = builtModel->GetCustomData();
    isAsyncMode_ = (listener != nullptr);

    HIAI_EXPECT_EXEC(PrepareModelManager(options, builtModel));
    HIAI_EXPECT_EXEC(PrepareModelManagerListener(listener));
    return SUCCESS;
}

/**
 * Initializes the manager with a client shared-memory allocator. The
 * allocator callbacks are wired into the compilation before the executor is
 * created so the runtime can request client memory during model load.
 */
Status ModelManagerNDKImpl::Init(const ModelInitOptions &options, const std::shared_ptr<IBuiltModel> &builtModel,
    const std::shared_ptr<IModelManagerListener> &listener, const std::shared_ptr<ISharedMemAllocator> &allocator)
{
    HIAI_EXPECT_NOT_NULL(builtModel);
    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    HIAI_EXPECT_TRUE(executor_ == nullptr);

    customModelData_ = builtModel->GetCustomData();

    // Fix: keep consistent with the listener-only overload — async mode is
    // determined by the presence of a listener and is consumed by
    // PrepareModelManager below (previously left at its prior value).
    isAsyncMode_ = (listener != nullptr);

    HIAI_EXPECT_EXEC(PrepareSharedMemAllocator(allocator, builtModel));

    HIAI_EXPECT_EXEC(PrepareModelManager(options, builtModel));

    HIAI_EXPECT_EXEC(PrepareModelManagerListener(listener));
    return SUCCESS;
}

/**
 * Loads model weights from the given directory via the executor.
 */
Status ModelManagerNDKImpl::InitWeights(const std::string& weightDir)
{
    HIAI_EXPECT_NOT_NULL(executor_);
    const char* dir = weightDir.c_str();
    return HIAI_NDK_HiAIExecutor_InitWeights(executor_.get(), dir);
}

/**
 * Looks up a named weight and wraps its raw memory in a non-owning IBuffer.
 *
 * @param weightName   name of the weight to query.
 * @param weightBuffer out: buffer viewing the executor-owned weight memory.
 */
Status ModelManagerNDKImpl::GetWeightBuffer(const std::string& weightName, std::shared_ptr<IBuffer>& weightBuffer)
{
    HIAI_EXPECT_NOT_NULL(executor_);

    void* raw = nullptr;
    size_t rawSize = 0;
    const Status status = HIAI_NDK_HiAIExecutor_GetWeightBuffer(executor_.get(), weightName.c_str(), &raw, &rawSize);
    if (status != SUCCESS) {
        FMK_LOGE("GetWeightBuffer failed.");
        return status;
    }

    HIAI_EXPECT_NOT_NULL(raw);
    HIAI_EXPECT_TRUE(rawSize > 0);
    // false: the IBuffer does not take ownership of the executor-owned memory.
    weightBuffer = CreateLocalBuffer(raw, rawSize, false);
    return SUCCESS;
}

/**
 * Flushes the range [offset, offset + size) of the named weight through the
 * executor.
 */
Status ModelManagerNDKImpl::FlushWeight(const std::string& weightName, size_t offset, size_t size)
{
    HIAI_EXPECT_NOT_NULL(executor_);
    const char* name = weightName.c_str();
    return HIAI_NDK_HiAIExecutor_FlushWeight(executor_.get(), name, offset, size);
}

/**
 * Forwards an async-run completion to the registered listener, if any.
 * The listener is invoked under listenerMutex_.
 */
void ModelManagerNDKImpl::OnRunDone(
    const Context &context, Status errCode, std::vector<std::shared_ptr<INDTensorBuffer>> &outputs)
{
    std::lock_guard<std::mutex> lock(listenerMutex_);
    if (listener_ == nullptr) {
        return;
    }
    listener_->OnRunDone(context, errCode, outputs);
}

/**
 * Forwards a service-death notification to the registered listener, if any.
 * The listener is invoked under listenerMutex_.
 */
void ModelManagerNDKImpl::OnServiceDied()
{
    std::lock_guard<std::mutex> lock(listenerMutex_);
    if (listener_ == nullptr) {
        return;
    }
    listener_->OnServiceDied();
}

void ModelManagerNDKImpl::OnRunDone(void* userData, OH_NN_ReturnCode errCode, void* outputTensor[], int32_t outputCount)
{
    /* third param: output */
    /* fourth param: outputNum */
    (void)outputTensor;
    (void)outputCount;
    RunAsyncContext* runAsyncContext = (RunAsyncContext *)userData;
    HIAI_EXPECT_NOT_NULL_VOID(runAsyncContext);
    HIAI_EXPECT_NOT_NULL_VOID(runAsyncContext->modelManager);
    runAsyncContext->modelManager->OnRunDone(
        runAsyncContext->context, static_cast<Status>(errCode), runAsyncContext->outputs);
    delete runAsyncContext;
}

void ModelManagerNDKImpl::OnServiceDied(void* userData)
{
    ModelManagerNDKImpl* impl = (ModelManagerNDKImpl *)userData;
    HIAI_EXPECT_NOT_NULL_VOID(impl);
    impl->OnServiceDied();
}

/**
 * Sets the model scheduling priority on the compilation object.
 */
Status ModelManagerNDKImpl::SetPriority(ModelPriority priority)
{
    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    const auto ndkPriority = static_cast<HIAI_ModelPriority>(priority);
    return HIAI_NDK_NNCompilation_SetPriority(nnCompilation_.get(), ndkPriority);
}

/**
 * Returns the loaded model's ID, or 0 when the query fails.
 */
uint32_t ModelManagerNDKImpl::GetModelID()
{
    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    uint32_t id = 0;
    // Best effort: on failure the default 0 is returned.
    (void)HIAI_NDK_HiAIExecutor_GetModelID(executor_.get(), &id);
    return id;
}

void ModelManagerNDKImpl::OnAllocate(void* userData, uint32_t requiredSize,
    HiAI_NativeHandle* handles[], size_t* handlesSize)
{
    HIAI_EXPECT_TRUE_VOID(userData != nullptr && handles != nullptr && handlesSize != nullptr);
    *handlesSize = 0;

    MemAllocaterContext* context = (MemAllocaterContext*)userData;
    HIAI_EXPECT_TRUE_VOID(context->modelManager != nullptr);
    std::shared_ptr<ISharedMemAllocator> memAllocator = context->modelManager->allocator_;
    HIAI_EXPECT_NOT_NULL_VOID(memAllocator);

    std::vector<hiai::NativeHandle> native = memAllocator->Allocate(requiredSize);
    if (native.size() > 10) { // 10: native handle max
        FMK_LOGE("Allocate client mem failed, size = %u", native.size());
        memAllocator->Free(native);
        return;
    }

    for (uint32_t i = 0; i < native.size(); i++) {
        handles[i] = HIAI_NDK_HiAINativeHandle_Create(native[i].fd, native[i].size, native[i].offset);
        if (handles[i] == nullptr) {
            for (uint32_t j = 0; j < i; j++) {
                HIAI_NDK_HiAINativeHandle_Destroy(&handles[j]);
            }
            context->modelManager->nativeHandle_.clear();
            return;
        }
        context->modelManager->nativeHandle_.push_back(std::make_pair(handles[i], native[i]));
    }

    *handlesSize = native.size();
}

void ModelManagerNDKImpl::OnFree(void* userData, HiAI_NativeHandle* handles[], size_t handlesSize)
{
    HIAI_EXPECT_TRUE_VOID(userData != nullptr && handles != nullptr && handlesSize != 0);

    MemAllocaterContext* context = (MemAllocaterContext*)userData;
    HIAI_EXPECT_TRUE_VOID(context->modelManager != nullptr);
    auto& nativeHandles = context->modelManager->nativeHandle_;
    HIAI_EXPECT_TRUE_VOID(handlesSize == nativeHandles.size());

    std::vector<NativeHandle> native;
    for (size_t i = 0; i < handlesSize; i++) {
        HIAI_NDK_HiAINativeHandle_Destroy(&nativeHandles[i].first);
        native.push_back(nativeHandles[i].second);
    }

    std::shared_ptr<ISharedMemAllocator> memAllocator = context->modelManager->allocator_;
    HIAI_EXPECT_NOT_NULL_VOID(memAllocator);
    memAllocator->Free(native);
}

/**
 * Wires the client shared-memory allocator callbacks (allocate/free) and a
 * per-manager context into the compilation options, then records the
 * allocator for use by the OnAllocate/OnFree trampolines.
 */
Status ModelManagerNDKImpl::PrepareSharedMemAllocator(
    const std::shared_ptr<ISharedMemAllocator>& allocator, const std::shared_ptr<IBuiltModel>& builtModel)
{
    HIAI_EXPECT_NOT_NULL(allocator);
    auto builtModelNDKImpl = std::dynamic_pointer_cast<BuiltModelNDKImpl>(builtModel);
    HIAI_EXPECT_NOT_NULL(builtModelNDKImpl);

    auto nnCompilation = builtModelNDKImpl->GetNNCompilation();
    HIAI_EXPECT_NOT_NULL(nnCompilation);

    Status status = HIAI_NDK_HiAIOptions_SetAllocate(nnCompilation.get(), OnAllocate);
    HIAI_EXPECT_TRUE_R(status == SUCCESS, FAILURE);
    status = HIAI_NDK_HiAIOptions_SetFree(nnCompilation.get(), OnFree);
    HIAI_EXPECT_TRUE_R(status == SUCCESS, FAILURE);

    context_ = new (std::nothrow) MemAllocaterContext;
    HIAI_EXPECT_NOT_NULL(context_);
    context_->modelManager = this;
    status = HIAI_NDK_HiAIOptions_SetUserData(nnCompilation.get(), context_);
    // Fix: treat any non-success code as failure (previously only the exact
    // FAILURE value triggered cleanup, leaking context_ for other codes).
    if (status != SUCCESS) {
        delete context_;
        context_ = nullptr;
        return FAILURE;
    }
    allocator_ = allocator;
    return SUCCESS;
}

namespace {
NN_Tensor* GetRawBufferFromNDTensorBuffer(const std::shared_ptr<INDTensorBuffer> &buffer)
{
    std::shared_ptr<NDTensorBufferNDKImpl> bufferNDKImpl = std::dynamic_pointer_cast<NDTensorBufferNDKImpl>(buffer);
    if (bufferNDKImpl == nullptr) {
        FMK_LOGE("invalid buffer");
        return nullptr;
    }
    return bufferNDKImpl->GetNNTensor();
}

std::vector<NN_Tensor *> Convert2NNTensorBuffers(const std::vector<std::shared_ptr<INDTensorBuffer>> &buffers)
{
    std::vector<NN_Tensor *> nnBuffers;
    for (size_t i = 0; i < buffers.size(); i++) {
        HIAI_EXPECT_NOT_NULL_R(buffers[i], nnBuffers);
        NN_Tensor* nnTensor = GetRawBufferFromNDTensorBuffer(buffers[i]);
        HIAI_EXPECT_NOT_NULL_R(nnTensor, nnBuffers);
        nnBuffers.push_back(nnTensor);
    }
    return nnBuffers;
}
}  // namespace

/**
 * Synchronous inference. Models carrying custom (AIPP) data are routed
 * through the AIPP path; plain models run directly on the executor.
 */
Status ModelManagerNDKImpl::Run(
    const std::vector<std::shared_ptr<INDTensorBuffer>> &inputs, std::vector<std::shared_ptr<INDTensorBuffer>> &outputs)
{
#ifdef AI_SUPPORT_AIPP_API
    // Fix: AippInputConverter and RunAippModel are only available when
    // AI_SUPPORT_AIPP_API is defined (see the guarded includes above); this
    // branch previously compiled unconditionally and broke non-AIPP builds.
    if (!customModelData_.type.empty()) {
        std::vector<std::shared_ptr<INDTensorBuffer>> dataInputs;
        std::vector<std::shared_ptr<IAIPPPara>> paraInputs;

        if (AippInputConverter::ConvertInputs(inputs, customModelData_, dataInputs, paraInputs) != hiai::SUCCESS) {
            return INVALID_PARAM;
        }
        Context context;
        return RunAippModel(context, dataInputs, paraInputs, outputs, 1000); // 1000: default timeout (ms)
    }
#endif

    std::vector<NN_Tensor *> nnInputs = Convert2NNTensorBuffers(inputs);
    HIAI_EXPECT_TRUE_R(nnInputs.size() == inputs.size(), INVALID_PARAM);

    std::vector<NN_Tensor *> nnOutputs = Convert2NNTensorBuffers(outputs);
    HIAI_EXPECT_TRUE_R(nnOutputs.size() == outputs.size(), INVALID_PARAM);

    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    return HIAI_NDK_NNExecutor_RunSync(executor_.get(), nnInputs, nnOutputs);
}

/**
 * Asynchronous inference: the result is delivered to the registered listener
 * via the OnRunDone trampoline. Requires Init() to have been called with a
 * listener.
 *
 * @param context caller context echoed back through the listener callback.
 * @param inputs  input tensors.
 * @param outputs output tensors; also captured in the async context so the
 *                callback can hand them to the listener.
 * @param timeout run timeout forwarded to the NDK (presumably milliseconds,
 *                matching RunAippModel's timeoutInMS — confirm).
 */
Status ModelManagerNDKImpl::RunAsync(const Context &context,
    const std::vector<std::shared_ptr<INDTensorBuffer>> &inputs, std::vector<std::shared_ptr<INDTensorBuffer>> &outputs,
    int32_t timeout)
{
    std::vector<NN_Tensor *> nnInputs = Convert2NNTensorBuffers(inputs);
    HIAI_EXPECT_TRUE_R(nnInputs.size() == inputs.size(), INVALID_PARAM);

    std::vector<NN_Tensor *> nnOutputs = Convert2NNTensorBuffers(outputs);
    HIAI_EXPECT_TRUE_R(nnOutputs.size() == outputs.size(), INVALID_PARAM);

    std::lock_guard<std::mutex> lock(modelManagerMutex_);

    // Async mode needs a listener registered at Init time.
    // NOTE(review): listener_ is read here under modelManagerMutex_ while it
    // is written under listenerMutex_ elsewhere — confirm this is benign.
    HIAI_EXPECT_NOT_NULL_R(listener_, UNSUPPORTED);

    RunAsyncContext* runContext = new (std::nothrow) RunAsyncContext();
    HIAI_EXPECT_NOT_NULL_R(runContext, MEMORY_EXCEPTION);

    runContext->context = context;
    runContext->modelManager = this;
    runContext->outputs = outputs;

    // On success, ownership of runContext passes to the OnRunDone callback,
    // which frees it; on submission failure it is reclaimed here.
    Status ret = HIAI_NDK_NNExecutor_RunAsync(executor_.get(), nnInputs, nnOutputs, timeout, runContext);
    if (ret != SUCCESS) {
        delete runContext;
    }

    return ret;
}

#ifdef AI_SUPPORT_AIPP_API
namespace {
std::vector<HiAI_AippParam *> Convert2NNAippParams(const std::vector<std::shared_ptr<IAIPPPara>> &aippParas)
{
    std::vector<HiAI_AippParam *> nnAippParams;
    for (size_t i = 0; i < aippParas.size(); i++) {
        HiAI_AippParam* nnAippParam = GetNNTensorAippParaFromAippPara(aippParas[i]);
        HIAI_EXPECT_NOT_NULL_R(nnAippParam, nnAippParams);
        nnAippParams.push_back(nnAippParam);
    }
    return nnAippParams;
}
}  // namespace

/**
 * Runs inference on an AIPP model: attaches the AIPP parameters to the input
 * tensors, then runs synchronously (no listener registered) or
 * asynchronously (result delivered via OnRunDone).
 *
 * @param context     caller context echoed back through the listener (async).
 * @param inputs      data input tensors.
 * @param aippParas   AIPP parameters to bind to the inputs.
 * @param outputs     output tensors, filled by the run.
 * @param timeoutInMS async-run timeout in milliseconds.
 */
Status ModelManagerNDKImpl::RunAippModel(const Context &context,
    const std::vector<std::shared_ptr<INDTensorBuffer>> &inputs,
    const std::vector<std::shared_ptr<IAIPPPara>> &aippParas, std::vector<std::shared_ptr<INDTensorBuffer>> &outputs,
    int32_t timeoutInMS)
{
    std::vector<NN_Tensor *> nnInputs = Convert2NNTensorBuffers(inputs);
    HIAI_EXPECT_TRUE_R(nnInputs.size() == inputs.size(), INVALID_PARAM);

    std::vector<NN_Tensor *> nnOutputs = Convert2NNTensorBuffers(outputs);
    HIAI_EXPECT_TRUE_R(nnOutputs.size() == outputs.size(), INVALID_PARAM);

    std::vector<HiAI_AippParam *> nnAippParams = Convert2NNAippParams(aippParas);
    HIAI_EXPECT_TRUE_R(nnAippParams.size() == aippParas.size(), INVALID_PARAM);

    Status status = HIAI_NDK_HiAITensor_SetAippParams(nnInputs, nnAippParams);
    HIAI_EXPECT_TRUE_R(status == SUCCESS, FAILURE);

    std::lock_guard<std::mutex> lock(modelManagerMutex_);

    if (listener_ == nullptr) {  // synchronous AIPP inference
        return HIAI_NDK_NNExecutor_RunSync(executor_.get(), nnInputs, nnOutputs);
    }

    RunAsyncContext* runContext = new (std::nothrow) RunAsyncContext();
    HIAI_EXPECT_NOT_NULL_R(runContext, MEMORY_EXCEPTION);

    runContext->context = context;
    runContext->modelManager = this;
    runContext->outputs = outputs;

    // asynchronous AIPP inference; on success the OnRunDone callback frees
    // runContext, on submission failure it is reclaimed here.
    Status ret = HIAI_NDK_NNExecutor_RunAsync(executor_.get(), nnInputs, nnOutputs, timeoutInMS, runContext);
    if (ret != SUCCESS) {
        delete runContext;
    }
    return ret;
}
#endif

/**
 * Canceling an in-flight inference is not supported by this backend.
 * @return FAILURE always.
 */
Status ModelManagerNDKImpl::Cancel()
{
    // Interface not supported.
    FMK_LOGE("Not supported.");
    return FAILURE;
}

/**
 * Releases the executor and the shared-memory callback context.
 * Safe to call when nothing is loaded.
 */
void ModelManagerNDKImpl::UnLoad()
{
    std::lock_guard<std::mutex> lock(modelManagerMutex_);
    executor_.reset();  // reset() already nulls the pointer; the extra assignment was redundant
    delete context_;    // deleting nullptr is a no-op, so no guard is needed
    context_ = nullptr;
}

/**
 * Full teardown: unloads the model, then detaches the listener so no further
 * callbacks are delivered.
 */
void ModelManagerNDKImpl::DeInit()
{
    UnLoad();
    std::lock_guard<std::mutex> lock(listenerMutex_);
    listener_ = nullptr;
}

/**
 * Factory: creates an NDK-backed model manager.
 * @return the manager, or nullptr on allocation failure.
 */
HIAI_MM_API_EXPORT std::shared_ptr<IModelManager> CreateModelManagerFromNDK()
{
    auto manager = make_shared_nothrow<ModelManagerNDKImpl>();
    return manager;
}

/**
 * Factory: creates an NDK-backed extended model manager.
 * @return the manager, or nullptr on allocation failure.
 */
HIAI_MM_API_EXPORT std::shared_ptr<IModelManagerExt> CreateModelManagerExtFromNDK()
{
    auto manager = make_shared_nothrow<ModelManagerNDKImpl>();
    return manager;
}

}  // namespace hiai